from hpo.hpo_base import *
import logging
import itertools
import pickle
import os
import numpy as np


# Configure the root logger once at import time: INFO level, timestamps with
# millisecond precision, followed by the emitting module and function name.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
_LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)


class RandomOptimizer(HyperparameterOptimizer):

    """A simple random search optimiser (with logging facilities).

    Each round draws ``batch_size`` configurations via
    ``env.config_space.sample_configuration()``, optionally repeats each one
    ``n_repetitions`` times, evaluates them, and appends the configurations
    and results to ``self.X`` / ``self.y``.
    """
    def __init__(self, env, max_iters: int, batch_size: int, n_repetitions: int = 1,
                 anneal_lr: bool = False,
                 log_interval: int = 1, log_dir = None, dummy=False):
        """Create the optimiser.

        Args:
            env: Environment exposing ``config_space`` and ``train_batch``.
            max_iters: Evaluation budget in configurations; it is multiplied
                by ``n_repetitions`` below, so repetitions do not eat into it.
            batch_size: Number of fresh configurations sampled per round.
            n_repetitions: How many times each sampled configuration is
                evaluated (repeats are placed consecutively in the batch).
            anneal_lr: Forwarded unchanged to ``env.train_batch``.
            log_interval: Progress is printed (and, if ``log_dir`` is set,
                saved) whenever the evaluation count reaches or passes a
                multiple of this value, and once more when the run finishes.
            log_dir: Directory to write ``stats.pkl`` into, or ``None`` to
                skip saving.
            dummy: If true, score configurations with a fixed synthetic
                quadratic instead of calling ``env.train_batch``.

        Raises:
            ValueError: If ``batch_size`` or ``log_interval`` is below 1
                (the former would make ``run`` loop forever, the latter would
                divide by zero).
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        if log_interval < 1:
            raise ValueError(f'log_interval must be >= 1, got {log_interval}')
        super(RandomOptimizer, self).__init__(env,
                                              max_iters=max_iters,
                                              batch_size=batch_size,
                                              n_repetitions=n_repetitions)
        # The loop in ``run`` counts individual evaluations, repetitions
        # included, so scale the budget accordingly.
        self.max_iters *= n_repetitions
        self.log_interval = log_interval
        self.log_dir = log_dir
        self.dummy = dummy
        self.anneal_lr = anneal_lr

    def suggest(self, n_suggestions: int = 1) -> List:
        """Return ``n_suggestions`` configurations sampled from the config space."""
        configs = [self.env.config_space.sample_configuration() for _ in range(n_suggestions)]
        return configs

    def _synthetic_objective(self, configs) -> list:
        """Score ``configs`` with a fixed random quadratic (used when ``dummy``)."""
        dim = configs[0].get_array().shape[0]
        # Fixed seeds keep the synthetic landscape identical across rounds and runs.
        coeff = np.random.RandomState(0).uniform(low=-1., high=1., size=dim)
        coeff2 = np.random.RandomState(1).uniform(low=-1., high=1., size=dim)
        # nansum skips NaN entries of the config vector (presumably inactive
        # conditional hyperparameters — TODO confirm against the config space).
        return [np.nansum(c.get_array() * coeff) + np.nansum(coeff2 * c.get_array() ** 2)
                for c in configs]

    def _save_stats(self) -> None:
        """Pickle ``[self.X, self.y]`` to ``<log_dir>/stats.pkl``."""
        stats_path = os.path.join(self.log_dir, 'stats.pkl')
        print(f'Saving intermediate results to {stats_path}')
        # Context manager guarantees the file is flushed and closed.
        with open(stats_path, 'wb') as f:
            pickle.dump([self.X, self.y], f)

    def run(self):
        """Run random search until the evaluation budget is exhausted.

        Returns:
            The accumulated ``(self.X, self.y)`` — evaluated configurations
            and their corresponding results, in evaluation order.
        """
        cur_iters = 0
        while cur_iters < self.max_iters:
            suggested_configs = self.suggest(n_suggestions=self.batch_size)
            if self.n_repetitions > 1:
                # Repeat each config consecutively: [a, b] -> [a, a, b, b].
                suggested_configs = [c for c in suggested_configs for _ in range(self.n_repetitions)]
            if self.dummy:
                trajectories = self._synthetic_objective(suggested_configs)
            else:
                trajectories = self.env.train_batch(suggested_configs, exp_idx_start=cur_iters, anneal_lr=self.anneal_lr)
            self.X += suggested_configs
            self.y += trajectories
            prev_iters = cur_iters
            cur_iters += len(suggested_configs)

            # cur_iters advances in steps of batch_size * n_repetitions, which
            # need not divide log_interval; log when a multiple is reached or
            # crossed, and always after the final round so results are saved.
            crossed_interval = cur_iters // self.log_interval > prev_iters // self.log_interval
            finished = cur_iters >= self.max_iters
            if crossed_interval or finished:
                print(f'HPO Iteration: {cur_iters}/{self.max_iters}. Best: {np.min(self.y)}')
                if self.log_dir is not None:
                    self._save_stats()
        return self.X, self.y
